// Copyright lowRISC contributors.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#include "sw/device/lib/arch/device.h"
#include "sw/device/lib/base/mmio.h"
#include "sw/device/lib/dif/dif_otbn.h"
#include "sw/device/lib/runtime/hart.h"
#include "sw/device/lib/runtime/ibex.h"
#include "sw/device/lib/runtime/log.h"
#include "sw/device/lib/testing/check.h"
#include "sw/device/lib/testing/test_main.h"

#include "hw/top_earlgrey/sw/autogen/top_earlgrey.h"

/**
 * Information about an embedded OTBN application image.
 *
 * All pointers reference data in the normal CPU address space.
 *
 * Each start/end pair delimits one segment of the image. `load_app()` uses
 * `end - start` as the segment size in bytes, so every `*_end` pointer refers
 * to the byte just past the last byte of its segment.
 */
typedef struct otbn_app {
  /**
   * Start of OTBN instruction memory.
   *
   * First byte of the image that `load_app()` copies to offset 0 of OTBN's
   * instruction memory.
   */
  const char *imem_start;
  /**
   * End of OTBN instruction memory.
   *
   * `load_app()` requires `imem_end - imem_start` to be non-zero and a
   * multiple of 4.
   */
  const char *imem_end;
  /**
   * Start of OTBN data memory.
   *
   * First byte of the image that `load_app()` copies to offset 0 of OTBN's
   * data memory.
   */
  const char *dmem_start;
  /**
   * End of OTBN data memory.
   *
   * May be equal to `dmem_start`, in which case `load_app()` does not load a
   * data image.
   */
  const char *dmem_end;
} otbn_app_t;

/**
 * A function (entry point) in an OTBN application.
 */
typedef struct otbn_func {
  /**
   * Application associated with this function.
   */
  const otbn_app_t *app;
  /**
   * Address of the function entry point in normal CPU address space.
   *
   * `call_function()` subtracts `app->imem_start` from this pointer to obtain
   * the start address handed to OTBN.
   */
  const char *func;
} otbn_func_t;

// Prepare info structs representing the barrett384 application, and the
// wrap_barrett384() function. The goal of these structs is to make the symbols
// for the application image available in a form that is easier to use.
//
// The symbols are generated by the linker script (the start/end symbols for
// instruction and data memories), as well as by the assembly sources (all
// global symbols, see `sw/otbn/code-snippets/barrett384.s`).
// Symbols are then prefixed with "_otbn_app_APPNAME_" by objcopy, as called
// from `util/otbn_build.py`.
//
// TODO: Think about macro-ifying or auto-generating some of this code in a
// future OTBN driver.

// Only the addresses of these symbols are used; they are declared as arrays so
// that the bare name yields the address.
extern const char _otbn_app_barrett384__imem_start[];
extern const char _otbn_app_barrett384__imem_end[];
extern const char _otbn_app_barrett384__dmem_start[];
extern const char _otbn_app_barrett384__dmem_end[];
extern const char _otbn_app_barrett384_wrap_barrett384[];

// Instruction and data image bounds of the barrett384 application.
static const otbn_app_t kOtbnAppBarrett384 = {
    .imem_start = _otbn_app_barrett384__imem_start,
    .imem_end = _otbn_app_barrett384__imem_end,
    .dmem_start = _otbn_app_barrett384__dmem_start,
    .dmem_end = _otbn_app_barrett384__dmem_end,
};

// The wrap_barrett384() entry point inside the barrett384 application.
static const otbn_func_t kOtbnFuncBarrett384WrapBarrett384 = {
    .app = &kOtbnAppBarrett384,
    .func = _otbn_app_barrett384_wrap_barrett384,
};

// Test configuration; this test sets `can_clobber_uart` to false.
// NOTE(review): presumably consumed by the test framework declared in
// test_main.h -- confirm.
const test_config_t kTestConfig = {
    .can_clobber_uart = false,
};

// OTBN DIF handle used by every helper in this file. It is initialized by
// dif_otbn_init() in test_main() before any helper runs.
static dif_otbn_t otbn;

/*
// Dump OTBN's memory. A development helper, unused otherwise.
static void otbn_dump_dmem() {
  uint32_t data[8];
  for (int i = 0; i < dif_otbn_get_dmem_size_bytes(&otbn) / 32; ++i) {
    dif_otbn_dmem_read(&otbn, i * 32, data, 32);
    LOG_INFO("DMEM @%04d: 0x%08x%08x%08x%08x%08x%08x%08x%08x", i, data[7],
             data[6], data[5], data[4], data[3], data[2], data[1], data[0]);
  }
}
*/

/**
 * Load an application into OTBN.
 *
 * Copies the instruction image (`imem_start` up to, but excluding,
 * `imem_end`) to offset 0 of OTBN's instruction memory. If the data image
 * (`dmem_start` up to, but excluding, `dmem_end`) is non-empty, it is copied
 * to offset 0 of OTBN's data memory as well.
 *
 * The instruction image must be non-empty and a multiple of 4 bytes long.
 * Violated preconditions and failing DIF calls are reported through CHECK().
 *
 * @param app OTBN application to load.
 */
static void load_app(const otbn_app_t *app) {
  CHECK(app->imem_end >= app->imem_start);
  const size_t imem_size = app->imem_end - app->imem_start;

  CHECK(app->dmem_end >= app->dmem_start);
  const size_t dmem_size = app->dmem_end - app->dmem_start;

  // Both image bounds are pointers and are therefore printed with %p, the
  // same way the DMEM message below prints its bounds.
  LOG_INFO(
      "Loading OTBN instruction memory image stored between address %p and "
      "%p (%d bytes)",
      app->imem_start, app->imem_end, imem_size);
  CHECK(imem_size > 0);
  CHECK(imem_size % 4 == 0);
  CHECK(dif_otbn_imem_write(&otbn, 0, app->imem_start, imem_size) == kDifOtbnOk,
        "Unable to write IMEM application image (.text) to OTBN.");

  // An application without initialized data has an empty data image; there is
  // nothing to copy in that case.
  if (dmem_size > 0) {
    LOG_INFO(
        "Loading OTBN data memory image stored between address %p and %p (%d "
        "bytes)",
        app->dmem_start, app->dmem_end, dmem_size);
    CHECK(
        dif_otbn_dmem_write(&otbn, 0, app->dmem_start, dmem_size) == kDifOtbnOk,
        "Unable to write DMEM application image (.data) to OTBN.");
  } else {
    LOG_INFO("No OTBN data memory image to load.");
  }
}

/**
 * Start execution of a function on OTBN.
 *
 * The OTBN start address is the distance, in bytes, between the function's
 * entry point and the start of its application's instruction image. That
 * address is handed to dif_otbn_start(); this function does not wait for the
 * operation to finish.
 *
 * @param func the function to be called
 */
static void call_function(const otbn_func_t *func) {
  const char *const imem_base = func->app->imem_start;
  const char *const entry_point = func->func;

  // Translate the CPU-side pointer into an offset relative to the image start.
  uint32_t start_address = entry_point - imem_base;

  LOG_INFO("Calling function at address 0x%x on OTBN.", start_address);
  CHECK(dif_otbn_start(&otbn, start_address) == kDifOtbnOk);
}

/**
 * Busy wait for OTBN to be done with its operation.
 */
static void otbn_wait_for_done(void) {
  bool busy = true;
  while (busy) {
    CHECK(dif_otbn_is_busy(&otbn, &busy) == kDifOtbnOk,
          "Unable to get busy status from OTBN");
  }
}

/**
 * Initialize OTBN's data memory with zeros
 */
static void zero_dmem(void) {
  int dmem_size_words = dif_otbn_get_dmem_size_bytes(&otbn) / sizeof(uint32_t);
  for (int i = 0; i < dmem_size_words; ++i) {
    const uint32_t zero = 0;
    dif_otbn_result_t rv =
        dif_otbn_dmem_write(&otbn, i * sizeof(uint32_t), &zero, sizeof(zero));
    CHECK(rv == kDifOtbnOk, "Error zeroing word %d in OTBN DMEM: %d", i, rv);
  }
}

/**
 * Run a 384-bit Barrett Multiplication on OTBN and check its result.
 *
 * This test is not aiming to exhaustively test the Barrett multiplication
 * itself, but test the interaction between device software and OTBN. As such,
 * only trivial parameters are used.
 *
 * The code executed on OTBN can be found in sw/otbn/code-snippets/barrett384.s.
 * The entry point wrap_barrett384() is called according to the calling
 * convention described in the OTBN assembly code file.
 *
 * Sequence: clear DMEM, load the application, write the operands a, b, the
 * modulus m and the Barrett constant u to DMEM, start the function, wait for
 * completion, read back c and compare it byte by byte with the expected
 * value. Every DIF call, including the final read-back, is checked.
 */
static void test_barrett384(void) {
  enum { kDataSizeBytes = 48 };

  zero_dmem();

  load_app(&kOtbnAppBarrett384);

  // a, first operand
  static const uint8_t a[kDataSizeBytes] = {10};

  // b, second operand
  static const uint8_t b[kDataSizeBytes] = {20};

  // m, modulus, max. length 384 bit with 2^384 > m > 2^383
  // We choose the modulus of P-384: m = 2**384 - 2**128 - 2**96 + 2**32 - 1
  static const uint8_t m[kDataSizeBytes] = {
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff,
      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff};

  // u, pre-computed Barrett constant (without u[384]/MSb of u which is always 1
  // for the allowed range but has to be set to 0 here).
  // u has to be pre-calculated as u = floor(2^768/m).
  static const uint8_t u[kDataSizeBytes] = {
      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01};

  // c, result, max. length 384 bit.
  uint8_t c[kDataSizeBytes] = {0};

  // c = (a * b) % m = (10 * 20) % m = 200
  static const uint8_t c_expected[kDataSizeBytes] = {200};

  LOG_INFO("Writing input arguments to DMEM");
  // TODO: Use symbols from the application to load these parameters once they
  // are available (#3998).
  CHECK(dif_otbn_dmem_write(&otbn, /*offset_bytes=*/0, &a, sizeof(a)) ==
        kDifOtbnOk);
  CHECK(dif_otbn_dmem_write(&otbn, /*offset_bytes=*/64, &b, sizeof(b)) ==
        kDifOtbnOk);
  CHECK(dif_otbn_dmem_write(&otbn, /*offset_bytes=*/256, &m, sizeof(m)) ==
        kDifOtbnOk);
  CHECK(dif_otbn_dmem_write(&otbn, /*offset_bytes=*/320, &u, sizeof(u)) ==
        kDifOtbnOk);

  // Keep the cycle counter values in an unsigned 64-bit type so that neither
  // the stored values nor their difference can overflow a signed int.
  const uint64_t t_start = ibex_mcycle_read();

  LOG_INFO("Calling wrap_barrett384()");
  call_function(&kOtbnFuncBarrett384WrapBarrett384);

  otbn_wait_for_done();

  const uint64_t t_end = ibex_mcycle_read();
  // The elapsed time is narrowed to 32 bit explicitly for logging.
  LOG_INFO("Function execution on OTBN took %d cycles (end-to-end).",
           (uint32_t)(t_end - t_start));

  LOG_INFO("Reading back result (c)");
  // A failed read would leave c at its zero initial value and only show up as
  // a misleading data mismatch below, so check the read itself first.
  CHECK(dif_otbn_dmem_read(&otbn, /*offset_bytes=*/512, &c, sizeof(c)) ==
            kDifOtbnOk,
        "Unable to read result (c) from OTBN DMEM.");

  for (int i = 0; i < kDataSizeBytes; ++i) {
    CHECK(c[i] == c_expected[i],
          "Unexpected result c at byte %d: 0x%x (actual) != 0x%x (expected)", i,
          c[i], c_expected[i]);
  }
}

/**
 * Test entry point.
 *
 * Initializes the file-level OTBN DIF handle for the OTBN instance at
 * `TOP_EARLGREY_OTBN_BASE_ADDR` and then runs the barrett384 test. Failures
 * are reported through CHECK() inside the individual steps.
 *
 * @return true once all steps have run.
 */
bool test_main(void) {
  LOG_INFO("Running OTBN DIF test");

  dif_otbn_config_t otbn_config = {
      .base_addr = mmio_region_from_addr(TOP_EARLGREY_OTBN_BASE_ADDR),
  };
  dif_otbn_result_t rv = dif_otbn_init(&otbn_config, &otbn);
  CHECK(rv == kDifOtbnOk, "dif_otbn_init() failed: %d", rv);

  LOG_INFO("Running barrett384 code on OTBN");
  test_barrett384();

  return true;
}
